/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file threaded_engine.cc
 * \brief implements base threaded engine.
 * \author Yutian Li
 */
#include <dmlc/logging.h>
#include <cassert>
#include <algorithm>
#include <condition_variable>
#include <mutex>
#include <utility>
#include "./threaded_engine.h"
#include "../common/cuda/utils.h"

namespace mxnet {
namespace engine {

#if ENGINE_DEBUG
// Debug-only live-object counters for the engine's bookkeeping types;
// logged on creation (see ThreadedVar ctor below) to help spot leaks.
std::atomic<std::size_t> OprBlock::counter{0};
std::atomic<std::size_t> VersionedVarBlock::counter{0};
std::atomic<std::size_t> ThreadedVar::counter{0};
std::atomic<std::size_t> ThreadedOpr::counter{0};
#endif  // ENGINE_DEBUG

// Construct a variable whose dependency queue starts at `head`,
// a pre-allocated empty VersionedVarBlock acting as the tail sentinel.
ThreadedVar::ThreadedVar(VersionedVarBlock* head) : head_{head} {
#if ENGINE_DEBUG
  LOG(INFO) << __func__ << " " << ++counter;
#endif  // ENGINE_DEBUG
}

// Register opr_block as a reader of this variable.
// Fast path: no write is queued, so the read may proceed immediately and the
// operator's wait counter is decremented. Otherwise the read is appended to
// the dependency queue behind the pending write.
inline void ThreadedVar::AppendReadDependency(OprBlock* opr_block) {
  std::lock_guard<std::mutex> lock{mutex_};
  if (pending_write_ == nullptr) {
    // invariant: is_ready_to_read()
    CHECK_GE(num_pending_reads_, 0);
    // STATE CHANGE
    ++num_pending_reads_;
    // decrease wait counter
    opr_block->decr_wait();
  } else {
    // A write is ahead of us: queue this read.
    // head_ is always the empty tail sentinel of the queue.
    auto&& new_var_block = VersionedVarBlock::New();
    assert(head_->next == nullptr);
    assert(head_->trigger == nullptr);
    assert(head_->write == false);
    // append things to next.
    head_->next    = new_var_block;
    head_->trigger = opr_block;
    head_          = new_var_block;
  }
}

// Register opr_block as a writer of this variable.
// The write is always appended to the queue; if it is the first pending write
// and no reads are in flight, it is triggered immediately and
// num_pending_reads_ is set to the kWriteTriggered marker.
// The new sentinel node is allocated before taking the lock to keep the
// critical section short.
inline void ThreadedVar::AppendWriteDependency(OprBlock* opr_block) {
  auto&& new_var_block = VersionedVarBlock::New();
  std::lock_guard<std::mutex> lock{mutex_};
  // invariant.
  assert(head_->next == nullptr);
  assert(head_->trigger == nullptr);
  assert(head_->write == false);
  // attach to head.
  head_->next    = new_var_block;
  head_->trigger = opr_block;
  head_->write   = true;

  // check if it is ready to write
  if (pending_write_ == nullptr) {
    // invariant: is_ready_to_read()
    pending_write_ = head_;
    CHECK_GE(num_pending_reads_, 0);
    if (num_pending_reads_ == 0) {
      // STATE CHANGE
      opr_block->decr_wait();
      num_pending_reads_ = kWriteTriggered;
    }
  } else {
    // An earlier write is queued; it (or its readers) must still be pending.
    CHECK_NE(num_pending_reads_, 0);
  }
  head_ = new_var_block;
}

// Called when one reader of this variable finishes.
// If this was the last outstanding read and a write is queued, that write
// becomes active; its operator is dispatched (outside the lock) once its own
// wait counter drops to zero.
template <typename Dispatcher>
inline void ThreadedVar::CompleteReadDependency(Dispatcher dispatcher) {
  OprBlock* trigger = nullptr;
  {
    // this is lock scope
    std::lock_guard<std::mutex> lock{mutex_};
    CHECK_GT(num_pending_reads_, 0);

    if (--num_pending_reads_ == 0) {
      if (pending_write_ != nullptr) {
        // STATE CHANGE
        trigger            = pending_write_->trigger;
        num_pending_reads_ = kWriteTriggered;
      }
    }
  }
  // Dispatch outside the lock so the dispatcher may re-enter this engine.
  if (trigger != nullptr && trigger->decr_wait() == 0) {
    dispatcher(trigger);
  }
}

// Called when the active write on this variable finishes.
// Under the lock: bump the version, handle pending deletion, detach the
// finished write's node, then scan the queue for the next runnable set
// (a run of consecutive reads, or the next write). Dispatching and node
// deletion happen outside the lock. Returns true iff the variable was
// marked to_delete_ and has now been fully freed.
template <typename Dispatcher>
inline bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) {
  // this is lock scope
  VersionedVarBlock *old_pending_write, *end_of_read_chain;
  OprBlock* trigger_write = nullptr;
  {
    std::lock_guard<std::mutex> lock{mutex_};
    // invariants
    assert(head_->next == nullptr);
    assert(pending_write_ != nullptr);
    CHECK_EQ(num_pending_reads_, kWriteTriggered);

    // increment version number
    ++version_;

    // really delete
    if (to_delete_) {
      // The delete op is the last queued op (its next is the tail sentinel,
      // as the assert below checks), so only two nodes remain to free.
      VersionedVarBlock* head = pending_write_->next;
      VersionedVarBlock::Delete(pending_write_);
      assert(head_ == head);
      VersionedVarBlock::Delete(head);
      return true;
    }
    // detach pending write
    old_pending_write = pending_write_;
    // search for chains to trigger
    end_of_read_chain = old_pending_write->next;
    // reset to 0 pending reads
    num_pending_reads_ = 0;
    // Count the consecutive reads queued right after the finished write;
    // they may all run concurrently.
    while (end_of_read_chain != head_ && end_of_read_chain->write == false) {
      ++num_pending_reads_;
      end_of_read_chain = end_of_read_chain->next;
    }
    if (end_of_read_chain == head_) {
      // Queue exhausted: no further write pending.
      pending_write_ = nullptr;
    } else {
      // check if there is pending reads, if not trigger write
      assert(end_of_read_chain->write == true);
      pending_write_ = end_of_read_chain;
      if (num_pending_reads_ == 0) {
        // mark write as already activated in this var
        num_pending_reads_ = kWriteTriggered;
        trigger_write      = end_of_read_chain->trigger;
      }
    }
  }
  // This is outside of lock scope
  // Be very careful, pending_write_ and num_pending_reads_
  // can change now, do not rely on these two variables.
  // The linked list \in [old_pending_write, end_of_read_chain)
  // is already detached from this Var.
  // So it is safe to modify these
  VersionedVarBlock* cur_head = old_pending_write->next;
  VersionedVarBlock::Delete(old_pending_write);
  // dispatch all the events
  while (cur_head != end_of_read_chain) {
    if (cur_head->trigger->decr_wait() == 0) {
      dispatcher(cur_head->trigger);
    }
    auto prev = cur_head;
    cur_head  = cur_head->next;
    assert(cur_head != nullptr);
    VersionedVarBlock::Delete(prev);
  }
  if (trigger_write != nullptr && trigger_write->decr_wait() == 0) {
    dispatcher(trigger_write);
  }
  return false;
}

inline void ThreadedVar::SetToDelete() {
  // Flag this variable for destruction; the flag is consumed by
  // CompleteWriteDependency() under the same mutex.
  std::lock_guard<std::mutex> guard{mutex_};
  to_delete_ = true;
}

inline bool ThreadedVar::ready_to_read() {
  // Thread-safe wrapper around the unlocked is_ready_to_read() predicate.
  std::lock_guard<std::mutex> guard{mutex_};
  const bool ready = this->is_ready_to_read();
  return ready;
}

inline size_t ThreadedVar::version() {
  // Snapshot the version counter under the mutex so readers never observe
  // a torn or stale value relative to dependency completion.
  std::lock_guard<std::mutex> guard{mutex_};
  const size_t current = this->version_;
  return current;
}

// implementation of threaded engine
// implementation of threaded engine
ThreadedVar* ThreadedEngine::NewVariable() {
  // Each variable owns one empty sentinel node that serves as the tail of
  // its pending-operation queue.
  auto* sentinel = VersionedVarBlock::New();
  return ThreadedVar::New(sentinel);
}

// Create an engine operator from an async function and its dependency sets.
// The generic VarHandles are downcast to engine-internal ThreadedVar
// pointers. In debug builds the dependency lists are checked for duplicates.
ThreadedOpr* ThreadedEngine::NewOperator(ThreadedEngine::AsyncFn fn,
                                         std::vector<VarHandle> const& const_vars,
                                         std::vector<VarHandle> const& mutable_vars,
                                         FnProperty prop,
                                         const char* opr_name,
                                         bool wait) {
  ThreadedOpr* opr = ThreadedOpr::New();
  opr->opr_name    = (opr_name != nullptr) ? std::string(opr_name) : std::string();
  opr->fn          = std::move(fn);
  opr->prop        = prop;
  opr->wait        = wait;
  opr->const_vars.resize(const_vars.size());
  opr->mutable_vars.resize(mutable_vars.size());
  std::transform(const_vars.begin(),
                 const_vars.end(),
                 opr->const_vars.begin(),
                 ThreadedVar::CastFromBase);
  std::transform(mutable_vars.begin(),
                 mutable_vars.end(),
                 opr->mutable_vars.begin(),
                 ThreadedVar::CastFromBase);
  if (ENGINE_DEBUG != 0) {
    CheckDuplicate(const_vars, mutable_vars);
  }
  return opr;
}

void ThreadedEngine::CheckDuplicate(std::vector<VarHandle> const& const_vars,
                                    std::vector<VarHandle> const& mutable_vars) {
  // Check for duplicates.
  auto use                 = const_vars;
  auto mutate              = mutable_vars;
  const size_t use_size    = use.size();
  const size_t mutate_size = mutate.size();
  std::sort(use.begin(), use.end());
  std::sort(mutate.begin(), mutate.end());
  for (std::size_t i = 0; i < use_size; ++i) {
    if (i != 0 && use.at(i) == use.at(i - 1)) {
      LOG(FATAL) << "duplicate items found in `const_vars`";
    }
  }
  for (std::size_t i = 0; i < mutate_size; ++i) {
    if (i != 0 && mutate.at(i) == mutate.at(i - 1)) {
      LOG(FATAL) << "duplicate items found in `mutable_vars`";
    }
  }
  std::size_t j = 0;
  for (std::size_t i = 0; i < use_size; ++i) {
    while (j < mutate_size && mutate.at(j) < use.at(i)) {
      ++j;
    }
    if (j == mutate_size) {
      break;
    }
    if (mutate.at(j) == use.at(i)) {
      LOG(FATAL) << "duplicate items found between `const_vars` and `mutable_vars`";
    }
  }
}

// Schedule deletion of an operator. The delete must not run until every
// in-flight use of the operator's vars has finished, so it is pushed as an
// engine task that mutably depends on all of them.
void ThreadedEngine::DeleteOperator(OprHandle op) {
  ThreadedOpr* threaded_opr = ThreadedOpr::CastFromBase(op);
  std::vector<VarHandle> deps(threaded_opr->const_vars.begin(), threaded_opr->const_vars.end());
  deps.insert(deps.end(), threaded_opr->mutable_vars.begin(), threaded_opr->mutable_vars.end());
  auto delete_task =
      [threaded_opr](RunContext, CallbackOnStart on_start, CallbackOnComplete on_complete) {
        on_start();
        ThreadedOpr::Delete(threaded_opr);
        on_complete();
      };
  this->PushAsync(std::move(delete_task),
                  Context::CPU(),
                  {},
                  deps,
                  FnProperty::kDeleteVar,
                  0,
                  "DeleteOperator");
}

// Push an operator into the dependency engine.
// Builds an OprBlock, registers it with every dependency var, and dispatches
// it immediately if all dependencies are already satisfied.
void ThreadedEngine::Push(OprHandle op, Context exec_ctx, int priority, bool profiling) {
  BulkFlush();
  ThreadedOpr* threaded_opr = ThreadedOpr::CastFromBase(op);
  if (profiling) {
    threaded_opr->opr_name =
        profiler::CustomOpProfiler::Get()->GenerateDisplayName(threaded_opr->opr_name.c_str());
  }
  OprBlock* opr_block = OprBlock::New();
  opr_block->opr      = threaded_opr;

  // Wait counter starts at (#deps + 1): each satisfied dependency decrements
  // it, and the extra +1 (removed below) keeps the operator from being
  // dispatched while dependencies are still being appended.
  opr_block->wait.store(
      static_cast<int>(threaded_opr->const_vars.size() + threaded_opr->mutable_vars.size() + 1));
  opr_block->ctx       = exec_ctx;
  opr_block->priority  = priority;
  opr_block->profiling = profiling;
  ++pending_;
  // Add read dependencies.
  for (auto&& i : threaded_opr->const_vars) {
    i->AppendReadDependency(opr_block);
  }
  // Add write dependencies.
  for (auto&& i : threaded_opr->mutable_vars) {
    i->AppendWriteDependency(opr_block);
  }
  // Drop the registration guard; if everything was already satisfied,
  // dispatch now (pusher_thread = true).
  if (opr_block->decr_wait() == 0) {
    this->PushToExecute(opr_block, true);
  }
}

// Wrap an async function in a temporary operator and push it.
// On CUDA builds, validates the target GPU id against the (lazily cached)
// device count first. The temporary operator is freed in OnComplete().
void ThreadedEngine::PushAsync(AsyncFn fn,
                               Context exec_ctx,
                               std::vector<VarHandle> const& const_vars,
                               std::vector<VarHandle> const& mutable_vars,
                               FnProperty prop,
                               int priority,
                               const char* opr_name,
                               bool wait) {
#if MXNET_USE_CUDA
  if (exec_ctx.dev_mask() == gpu::kDevMask) {
    // NOTE(review): device_count_ is read/written without synchronization;
    // concurrent first pushes may both query CUDA. Appears benign since both
    // writers store the same value — confirm device_count_ is atomic or
    // races are acceptable here.
    if (device_count_ < 0) {
      int tmp = -1;
      cudaGetDeviceCount(&tmp);
      device_count_ = tmp;
      CHECK_GT(device_count_, 0) << "GPU usage requires at least 1 GPU";
    }
    CHECK_LT(exec_ctx.dev_id, device_count_)
        << "Invalid GPU Id: " << exec_ctx.dev_id
        << ", Valid device id should be less than device_count: " << device_count_;
  }
#endif
  const bool profiling = profiler_->IsProfiling(profiler::Profiler::kImperative);
  ThreadedOpr* opr     = NewOperator(std::move(fn), const_vars, mutable_vars, prop, opr_name, wait);
  // temporary => deleted by OnComplete() after it runs once.
  opr->temporary       = true;
  Push(opr, exec_ctx, priority, profiling);
}

// Push a synchronous function.
// When bulking is disabled, or the op is not a plain normal-priority op,
// it is wrapped and pushed through the regular async path. Otherwise the op
// is appended to the current bulk (flushing first on a context switch).
void ThreadedEngine::PushSync(SyncFn exec_fn,
                              Context exec_ctx,
                              std::vector<VarHandle> const& const_vars,
                              std::vector<VarHandle> const& mutable_vars,
                              FnProperty prop,
                              int priority,
                              const char* opr_name) {
  if (!bulk_size() || prop != FnProperty::kNormal || priority) {
    this->PushAsync(
        [exec_fn](RunContext ctx, CallbackOnStart on_start, CallbackOnComplete on_complete) {
          on_start();
          exec_fn(ctx);
          on_complete();
        },
        exec_ctx,
        const_vars,
        mutable_vars,
        prop,
        priority,
        opr_name);
    return;
  }

  // A bulk may only contain ops from one context; flush on context change.
  const BulkStatus& bulk_status = *BulkStatusStore::Get();
  if (bulk_status.count && exec_ctx != bulk_status.ctx)
    BulkFlush();
  BulkAppend(exec_fn, exec_ctx, const_vars, mutable_vars);
}

// Schedule deletion of a variable: run delete_fn after every queued use of
// the var has completed, then let the engine recycle the var itself.
void ThreadedEngine::DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) {
  ThreadedVar* threaded_var = ThreadedVar::CastFromBase(var);
  auto deletion_task        = [delete_fn, threaded_var](
                           RunContext ctx, CallbackOnStart on_start, CallbackOnComplete on_complete) {
    on_start();
    // Mark variable as orphan,
    // so during `ThreadedEngine::OnComplete` it could be recycled.
    threaded_var->SetToDelete();
    delete_fn(ctx);
    on_complete();
  };
  this->PushAsync(std::move(deletion_task),
                  exec_ctx,
                  {},
                  {var},
                  FnProperty::kDeleteVar,
                  0,
                  "DeleteVariable");
}

// Block the calling thread until `var` is ready to read.
// Fast path: if the var is already readable, only (re)throw any recorded
// exception. Otherwise push a read-dependent no-op that signals `done`
// through finished_cv_, then wait for it (or for engine shutdown via kill_).
void ThreadedEngine::WaitForVar(VarHandle var) {
  BulkFlush();
  ThreadedVar* threaded_var = ThreadedVar::CastFromBase(var);
  if (threaded_var->ready_to_read()) {
    ThrowException(threaded_var);
    return;
  }
  if (engine_info_) {
    LOG(INFO) << "Wait for " << threaded_var;
    debug_wait_var_ = threaded_var;
  }
  // `done` lives on this stack frame; the pushed op is waited on below, so
  // the reference capture cannot dangle.
  std::atomic<bool> done{false};
  this->PushAsync(
      [this, &done](RunContext, CallbackOnStart on_start, CallbackOnComplete on_complete) {
        on_start();
        if (engine_info_) {
          LOG(INFO) << "Sync is executed";
        }
        {
          // Store under finished_m_ so the waiter cannot miss the notify
          // between checking the predicate and sleeping.
          std::unique_lock<std::mutex> lock{finished_m_};
          done.store(true);
        }
        finished_cv_.notify_all();
        if (engine_info_) {
          LOG(INFO) << "Sync is notified";
        }
        on_complete();
      },
      Context::CPU(),
      {var},
      {},
      FnProperty::kNormal,
      0,
      "WaitForVar",
      true);
  {
    std::unique_lock<std::mutex> lock{finished_m_};
    finished_cv_.wait(lock, [this, &done]() { return done.load() || kill_.load(); });
  }

  // Propagate any exception the producing operator recorded on the var.
  ThrowException(threaded_var);
}

// Block until all pushed operations have completed (or the engine is killed),
// then rethrow the first recorded global exception, clearing all of them so
// subsequent WaitToRead/WaitForAll calls do not throw again.
void ThreadedEngine::WaitForAll() {
  BulkFlush();
  std::unique_lock<std::mutex> lock{finished_m_};
  finished_cv_.wait(lock, [this]() { return pending_.load() == 0 || kill_.load(); });
  std::exception_ptr exception_to_rethrow = nullptr;
  if (!global_exception_refs_.empty()) {
    // iterate through all exception refs
    for (const auto& global_exception_ref : global_exception_refs_) {
      // the first exception will be saved to be rethrown later
      if (*global_exception_ref != nullptr && exception_to_rethrow == nullptr) {
        exception_to_rethrow = *global_exception_ref;
      }
      // clear exceptions, WaitToRead following WaitForAll shouldn't throw
      *global_exception_ref = nullptr;
    }
    // A waitall following a waitall shouldn't throw any exceptions
    global_exception_refs_.clear();
    if (exception_to_rethrow != nullptr) {
      std::rethrow_exception(exception_to_rethrow);
    }
  }
}

// Finalize a finished operator: release its read/write dependencies
// (dispatching any newly-runnable ops), propagate exceptions to the written
// vars, recycle vars marked for deletion, update the pending counter, and
// free the operator itself if it was temporary.
inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
  // Read `temporary` before releasing deps: see the comment below about
  // threaded_opr's lifetime.
  bool is_temporary_opr = threaded_opr->temporary;
  // Mark complete for read variables
  for (auto&& i : threaded_opr->const_vars) {
    i->CompleteReadDependency([this](OprBlock* opr) { this->PushToExecute(opr, false); });
  }
  // Mark complete for write variables.
  for (auto&& i : threaded_opr->mutable_vars) {
    if (threaded_opr->opr_exception && *threaded_opr->opr_exception) {
      i->var_exception = threaded_opr->opr_exception;
      // add current operator exceptions to global exceptions if not already
      // added
      AddToGlobalExceptions(threaded_opr->opr_exception);
    }
    const bool debug_info = (engine_info_ && debug_wait_var_ == i);
    if (debug_info) {
      LOG(INFO) << "Complete write dep for " << i;
    }
    // CompleteWriteDependency returns true when the var was marked for
    // deletion and its queue has been torn down.
    const bool to_delete = i->CompleteWriteDependency([this, debug_info](OprBlock* opr) {
      if (debug_info) {
        LOG(INFO) << "PushToExecute " << opr;
        debug_push_opr_ = opr;
      }
      this->PushToExecute(opr, false);
      if (debug_info) {
        LOG(INFO) << "Fin PushToExecute " << opr;
      }
    });
    if (to_delete) {
#if MXNET_USE_CUDA
      // Drop any CUDA events still referenced by the var before freeing it.
      auto& sync_obj = i->sync_object;
      {
        std::lock_guard<std::mutex> l(sync_obj.mutex);
        sync_obj.reader_events.clear();
        sync_obj.writer_event.clear();
      }
#endif
      ThreadedVar::Delete(i);
    }
  }
  // The function been pushed from `ThreadedEngine::DeleteOperator`
  // could execute right after we mark all vars as complete, so if
  // threaded_opr is not temporary, its value is not reliable
  // anymore start from here.
  int npending = 0;
  {
    std::unique_lock<std::mutex> lock{finished_m_};
    npending = --pending_;
  }
  CHECK_GE(npending, 0);
  if (npending == 0) {
    // no need to grab lock when notify.
    finished_cv_.notify_all();
  }

  // delete operator if it is temporary
  if (is_temporary_opr) {
    ThreadedOpr::Delete(threaded_opr);
  }
}

// Rethrow (and clear) any exception recorded on the variable, so the same
// error is not delivered to the caller twice.
inline void ThreadedEngine::ThrowException(ThreadedVar* threaded_var) {
  if (!threaded_var->var_exception || !*threaded_var->var_exception) {
    return;
  }
  std::exception_ptr pending   = *threaded_var->var_exception;
  *threaded_var->var_exception = nullptr;
  std::rethrow_exception(pending);
}

// Public entry point: surface any pending exception attached to `var`.
void ThreadedEngine::Throw(VarHandle var) {
  ThrowException(ThreadedVar::CastFromBase(var));
}

// Static completion callback handed to worker threads.
// Records any error as an exception on the operator, stops the profiling
// timer if one was started, runs OnComplete(), and frees the OprBlock.
void ThreadedEngine::OnCompleteStatic(Engine* engine, void* opr_block_, const dmlc::Error* error) {
  OprBlock* opr_block       = static_cast<OprBlock*>(opr_block_);
  ThreadedOpr* threaded_opr = opr_block->opr;
  if (error != nullptr) {
    // Convert the dmlc error into an exception_ptr shared with every var
    // this operator writes (propagated in OnComplete()).
    auto ex_p                   = std::make_exception_ptr(*error);
    threaded_opr->opr_exception = std::make_shared<std::exception_ptr>(ex_p);
  }
  // Use !empty() rather than size() as an implicit bool.
  if (opr_block->profiling && !threaded_opr->opr_name.empty()) {
    // record operator end timestamp
    opr_block->opr_profile->stop();
  }
  static_cast<ThreadedEngine*>(engine)->OnComplete(threaded_opr);
  OprBlock::Delete(opr_block);
}

// Default start hook: intentionally a no-op. Device-specific hooks
// (OnStartCPU / OnStartGPU below) perform event synchronization instead.
void ThreadedEngine::OnStartStatic(Engine* engine, void* opr_block, const dmlc::Error* error) {
  // no-op
}

#if MXNET_USE_CUDA
static inline void AddEventHelper(std::unordered_map<cudaStream_t, EventInfo>* events_per_stream,
                                  const EventInfo& cuda_event) {
  auto event_stream = cuda_event.stream;
  if (events_per_stream->count(event_stream) > 0) {
    if ((*events_per_stream)[event_stream].pool_index < cuda_event.pool_index) {
      (*events_per_stream)[event_stream] = cuda_event;
    }
  } else {
    (*events_per_stream).emplace(event_stream, cuda_event);
  }
}

// True iff the MXNET_ENGINE_TYPE environment variable contains "Async".
static inline bool IsEngineAsync() {
  const std::string engine_type = dmlc::GetEnv("MXNET_ENGINE_TYPE", std::string(""));
  return engine_type.find("Async") != std::string::npos;
}

// CPU-side start hook for the async dependency engine: before the op runs,
// block on every live CUDA event attached to its vars so host code observes
// completed GPU work. No-op unless MXNET_ENGINE_TYPE contains "Async".
// Fixes: remove_if predicates took EventInfo by value (one copy per element)
// and the final loop copied each map pair; both now use references.
void ThreadedEngine::OnStartCPU(Engine* engine, void* opr_block, const dmlc::Error* error) {
  static bool use_new_dep_engine = IsEngineAsync();
  if (!use_new_dep_engine) {
    return;
  }
  ThreadedOpr* threaded_opr = static_cast<OprBlock*>(opr_block)->opr;
  // Deduplicate events per stream; only the newest per stream is waited on.
  std::unordered_map<cudaStream_t, EventInfo> event_per_stream;
  for (auto* read_var : threaded_opr->const_vars) {
    auto& sync_obj = read_var->sync_object;
    std::lock_guard<std::mutex> l(sync_obj.mutex);
    auto& reader_events = sync_obj.reader_events;
    // check for expired events and delete them
    reader_events.erase(std::remove_if(reader_events.begin(),
                                       reader_events.end(),
                                       [&](const EventInfo& e_i) { return e_i.event.expired(); }),
                        reader_events.end());
    for (auto& cuda_event : reader_events) {
      AddEventHelper(&event_per_stream, cuda_event);
    }
    if (!sync_obj.writer_event.empty()) {
      if (sync_obj.writer_event[0].event.expired()) {
        sync_obj.writer_event.clear();
      } else {
        AddEventHelper(&event_per_stream, sync_obj.writer_event[0]);
      }
    }
  }

  for (auto* write_var : threaded_opr->mutable_vars) {
    auto& sync_obj = write_var->sync_object;
    std::lock_guard<std::mutex> l(sync_obj.mutex);
    auto& reader_events = sync_obj.reader_events;
    // check for expired events and delete them
    reader_events.erase(std::remove_if(reader_events.begin(),
                                       reader_events.end(),
                                       [&](const EventInfo& e_i) { return e_i.event.expired(); }),
                        reader_events.end());
    for (auto& cuda_event : reader_events) {
      AddEventHelper(&event_per_stream, cuda_event);
    }
    if (!sync_obj.writer_event.empty()) {
      if (sync_obj.writer_event[0].event.expired()) {
        sync_obj.writer_event.clear();
      } else {
        AddEventHelper(&event_per_stream, sync_obj.writer_event[0]);
      }
    }
  }
  // Host-block on one event per stream (only non-expired ones remain).
  for (const auto& event : event_per_stream) {
    auto ev = event.second.event.lock();
    MSHADOW_CUDA_CALL(cudaEventSynchronize(*ev));
  }
}

// GPU-side start hook for the async dependency engine: make the worker
// stream wait (device-side, via cudaStreamWaitEvent) on events recorded by
// other streams that produced or are reading this op's vars. No-op unless
// MXNET_ENGINE_TYPE contains "Async".
// Fixes: remove_if predicates took EventInfo by value (one copy per element)
// and the final loop copied each map pair; both now use references.
void ThreadedEngine::OnStartGPU(Engine* engine, void* sync_info, const dmlc::Error* error) {
  static bool use_new_dep_engine = IsEngineAsync();
  if (!use_new_dep_engine) {
    return;
  }
  auto* info = reinterpret_cast<GPUWorkerSyncInfo*>(sync_info);
  CHECK(info->stream != nullptr);
  auto* worker_stream       = reinterpret_cast<mshadow::Stream<gpu>*>(info->stream);
  ThreadedOpr* threaded_opr = static_cast<OprBlock*>(info->opr_block)->opr;
  // Deduplicate events per stream; only the newest per stream is waited on.
  std::unordered_map<cudaStream_t, EventInfo> event_per_stream;
  for (auto* read_var : threaded_opr->const_vars) {
    auto& sync_obj = read_var->sync_object;
    std::lock_guard<std::mutex> l(sync_obj.mutex);
    auto& reader_events = sync_obj.reader_events;
    // check for expired events and delete them
    reader_events.erase(std::remove_if(reader_events.begin(),
                                       reader_events.end(),
                                       [&](const EventInfo& e_i) { return e_i.event.expired(); }),
                        reader_events.end());
    for (auto& writer : sync_obj.writer_event) {
      if (writer.event.expired()) {
        sync_obj.writer_event.clear();
        break;
      }
      if (writer.stream != worker_stream->stream_) {
        // if there is already a reader on the same stream as us,
        // it already synced with that writer and we can rely on
        // the ongoing sync
        bool found = false;
        for (const auto& reader : reader_events) {
          if (reader.stream == worker_stream->stream_) {
            found = true;
            break;
          }
        }
        if (!found) {
          AddEventHelper(&event_per_stream, writer);
        }
      }
    }
  }
  for (auto* write_var : threaded_opr->mutable_vars) {
    auto& sync_obj = write_var->sync_object;
    std::lock_guard<std::mutex> l(sync_obj.mutex);
    // check for expired events and delete them
    auto& reader_events = sync_obj.reader_events;
    reader_events.erase(std::remove_if(reader_events.begin(),
                                       reader_events.end(),
                                       [&](const EventInfo& e_i) { return e_i.event.expired(); }),
                        reader_events.end());
    // if there are some readers, we wait for them
    for (auto& cuda_event : reader_events) {
      if (worker_stream->stream_ != cuda_event.stream) {
        AddEventHelper(&event_per_stream, cuda_event);
      }
    }
    if (!sync_obj.writer_event.empty()) {
      if (sync_obj.writer_event[0].event.expired()) {
        sync_obj.writer_event.clear();
      } else {
        if (worker_stream->stream_ != sync_obj.writer_event[0].stream) {
          AddEventHelper(&event_per_stream, sync_obj.writer_event[0]);
        }
      }
    }
  }
  // Queue a device-side wait per stream; does not block the host thread.
  for (const auto& event : event_per_stream) {
    auto ev = event.second.event.lock();
    MSHADOW_CUDA_CALL(cudaStreamWaitEvent(worker_stream->stream_, *ev, 0));
  }
}

// GPU-side completion hook. In legacy mode, host-synchronize on the worker
// stream before the usual completion path. In async mode, record a fresh
// pool event on the worker stream and publish it on every var this operator
// touched: as a reader event on const vars (replacing any prior event for
// the same stream), and as the sole writer event on mutable vars.
// Fix: the reader scan index was `int`, compared against size() — a
// signed/unsigned mismatch; now std::size_t.
void ThreadedEngine::OnCompleteGPU(Engine* engine, void* sync_info, const dmlc::Error* error) {
  auto* info = reinterpret_cast<GPUWorkerSyncInfo*>(sync_info);
  CHECK(info->stream != nullptr);

  auto* worker_stream            = reinterpret_cast<mshadow::Stream<gpu>*>(info->stream);
  static bool use_new_dep_engine = IsEngineAsync();

  if (!use_new_dep_engine) {
    worker_stream->Wait();
    ThreadedEngine::OnCompleteStatic(engine, info->opr_block, error);
    GPUWorkerSyncInfo::Delete(info);
    return;
  }

  ThreadedOpr* threaded_opr    = static_cast<OprBlock*>(info->opr_block)->opr;
  auto* event_pool             = static_cast<CUDAEventPool*>(info->event_pool);
  auto [event, event_pool_idx] = event_pool->GetNextEvent();  // NOLINT(*)
  auto ev                      = event.lock();
  MSHADOW_CUDA_CALL(cudaEventRecord(*ev, worker_stream->stream_));
  for (auto* read_var : threaded_opr->const_vars) {
    auto& sync_obj = read_var->sync_object;
    std::lock_guard<std::mutex> l(sync_obj.mutex);
    // If some reader event is already recorded on the same stream,
    // we want to replace ourselves by it
    std::size_t i;  // unsigned to match reader_events.size()
    for (i = 0; i < sync_obj.reader_events.size(); ++i) {
      auto stream = sync_obj.reader_events[i].stream;
      if (stream == worker_stream->stream_) {
        sync_obj.reader_events[i].event      = event;
        sync_obj.reader_events[i].pool_index = event_pool_idx;
        break;
      }
    }
    if (i == sync_obj.reader_events.size()) {
      // No prior event for this stream: append a new reader event.
      sync_obj.reader_events.push_back({event, worker_stream->stream_, event_pool_idx});
    }
  }

  for (auto* write_var : threaded_opr->mutable_vars) {
    auto& sync_obj = write_var->sync_object;
    std::lock_guard<std::mutex> l(sync_obj.mutex);
    // A completed write supersedes all previous events on the var.
    sync_obj.reader_events.clear();
    sync_obj.writer_event.clear();
    sync_obj.writer_event.push_back({event, worker_stream->stream_, event_pool_idx});
  }

  ThreadedEngine::OnCompleteStatic(engine, info->opr_block, error);
  GPUWorkerSyncInfo::Delete(info);
}
#endif

}  // namespace engine
}  // namespace mxnet
